cmake_minimum_required(VERSION 3.16)

# Root directory of the core test sources; test SRC paths are built from it.
set(DIR "${FL_CORE_DIR}/test")

# Libraries every test links against: flashlight itself, plus whatever
# CMAKE_DL_LIBS names for dlopen support (may be empty on some platforms,
# in which case nothing is appended).
set(LIBS flashlight)
list(APPEND LIBS ${CMAKE_DL_LIBS})

# Unconditionally registered core tests. Each entry is a source path relative
# to ${DIR}; build_test() (defined elsewhere) is called once per entry, in the
# order listed, with the shared ${LIBS}.
foreach(_fl_test_src IN ITEMS
    autograd/AutogradTest.cpp
    autograd/AutogradReductionTest.cpp
    autograd/AutogradBinaryOpsTest.cpp
    autograd/AutogradUnaryOpsTest.cpp
    autograd/AutogradNormalizationTest.cpp
    autograd/AutogradRnnTest.cpp
    autograd/AutogradConv2DTest.cpp
    common/DevicePtrTest.cpp
    common/DynamicBenchmarkTest.cpp
    common/HistogramTest.cpp
    common/LoggingTest.cpp
    common/SerializationTest.cpp
    common/UtilsTest.cpp
    optim/OptimTest.cpp
    runtime/DeviceManagerTest.cpp
    runtime/DeviceTest.cpp
    runtime/DeviceTypeTest.cpp
    nn/ModuleTest.cpp
    nn/NNSerializationTest.cpp
    nn/NNUtilsTest.cpp
    dataset/DatasetTest.cpp
    dataset/DatasetUtilsTest.cpp
    meter/MeterTest.cpp
)
  build_test(SRC ${DIR}/${_fl_test_src} LIBS ${LIBS})
endforeach()
# With CMP0124 OLD the loop variable survives the loop (as empty); drop it so
# no helper variable leaks into directory scope.
unset(_fl_test_src)
# Distributed test, registered only when FL_BUILD_DISTRIBUTED is set.
if (FL_BUILD_DISTRIBUTED)
  build_test(
    SRC ${DIR}/distributed/AllReduceTest.cpp
    LIBS ${LIBS}
  )
endif()

# Contrib module tests, registered only when FL_BUILD_CONTRIB is set.
# Entries are relative to ${DIR}/contrib/modules and keep their original order.
if (FL_BUILD_CONTRIB)
  foreach(_fl_test_src IN ITEMS
      ContribModuleTest.cpp
      ContribSerializationTest.cpp
  )
    build_test(SRC ${DIR}/contrib/modules/${_fl_test_src} LIBS ${LIBS})
  endforeach()
  unset(_fl_test_src)
endif()
# Tensor API tests, always registered. Paths are relative to ${DIR} and are
# passed to build_test() in the order listed.
foreach(_fl_test_src IN ITEMS
    tensor/TensorBaseTest.cpp
    tensor/TensorBinaryOpsTest.cpp
    tensor/TensorBLASTest.cpp
    tensor/TensorReductionTest.cpp
    tensor/TensorUnaryOpsTest.cpp
    tensor/ComputeTest.cpp
    tensor/IndexTest.cpp
    tensor/ShapeTest.cpp
)
  build_test(SRC ${DIR}/${_fl_test_src} LIBS ${LIBS})
endforeach()
unset(_fl_test_src)

# CUDA runtime tests, registered only when FL_USE_CUDA is set.
if (FL_USE_CUDA)
  foreach(_fl_test_src IN ITEMS
      runtime/CUDADeviceTest.cpp
      runtime/CUDAStreamTest.cpp
  )
    build_test(SRC ${DIR}/${_fl_test_src} LIBS ${LIBS})
  endforeach()
  unset(_fl_test_src)
endif()

# Tensor backend tests.
#
# Skipped for shared-library builds with MSVC.
# NOTE(review): presumably because these tests use backend internals that are
# not exported from a Windows DLL — confirm.
#
# The condition is evaluated once into a real boolean (ON/OFF) so the variable
# can be tested idiomatically with if(_build_tensor_backend_tests) and still
# works for readers that expand it as if(${_build_tensor_backend_tests}).
if (NOT MSVC OR NOT BUILD_SHARED_LIBS)
  set(_build_tensor_backend_tests ON)
else()
  set(_build_tensor_backend_tests OFF)
endif()

if (_build_tensor_backend_tests)
  build_test(SRC ${DIR}/tensor/TensorAdapterTest.cpp LIBS ${LIBS})

  # ArrayFire backend tests.
  if (FL_USE_ARRAYFIRE)
    build_test(SRC ${DIR}/tensor/TensorExtensionTest.cpp LIBS ${LIBS})
    build_test(SRC ${DIR}/tensor/af/ArrayFireTensorBaseTest.cpp LIBS ${LIBS})
    build_test(SRC ${DIR}/tensor/af/CachingMemoryManagerTest.cpp LIBS ${LIBS})
    build_test(SRC ${DIR}/tensor/af/MemoryFrameworkTest.cpp LIBS ${LIBS})
    build_test(SRC ${DIR}/tensor/af/MemoryInitTest.cpp LIBS ${LIBS})
    # Only registered when FL_ARRAYFIRE_USE_CPU is also set.
    if (FL_ARRAYFIRE_USE_CPU)
      build_test(SRC ${DIR}/tensor/af/ArrayFireCPUStreamTest.cpp LIBS ${LIBS})
    endif()
  endif()

  # OneDNN backend tests.
  if (FL_USE_ONEDNN)
    build_test(SRC ${DIR}/tensor/onednn/OneDnnCPUStreamTest.cpp LIBS ${LIBS})
    build_test(SRC ${DIR}/tensor/onednn/OneDnnTensorTest.cpp LIBS ${LIBS})
  endif()

  # JIT tests; the op-fusion test additionally requires FL_USE_ONEDNN.
  if (FL_USE_JIT)
    build_test(SRC ${DIR}/tensor/jit/JitEvaluatorTest.cpp LIBS ${LIBS})
    build_test(SRC ${DIR}/tensor/jit/JitNodeTest.cpp LIBS ${LIBS})
    build_test(SRC ${DIR}/tensor/jit/JitScalarFoldingTest.cpp LIBS ${LIBS})
    build_test(SRC ${DIR}/tensor/jit/JitTensorTest.cpp LIBS ${LIBS})
    if (FL_USE_ONEDNN)
      build_test(SRC ${DIR}/tensor/jit/JitOneDnnOpFusionTest.cpp LIBS ${LIBS})
    endif()
  endif()
endif()
